import re
import requests
from typing import Any, Dict, Union, List, Tuple
from rest_framework.views import APIView
from rest_framework.decorators import action
from rest_framework import viewsets
from rest_framework.mixins import ListModelMixin
from apscheduler.schedulers.background import BackgroundScheduler
from django_apscheduler.jobstores import register_job

from tzlocal import get_localzone
from django.conf import settings
from django.utils.timezone import localdate, localtime
from django_filters.rest_framework import DjangoFilterBackend
from clogger import logger
from lib.response import *
# from apps.accounts.authentication import Authentication
from rest_framework.response import Response
from apps.vul.models import *
from apps.vul.vul import update_sa as upsa, update_vul as upvul
from apps.vul.vul import fix_cve, get_unfix_cve
from apps.vul.serializer import VulAddrListSerializer, VulAddrModifySerializer
from .async_fetch import FetchHost
from .serializer import SecurityAdvisorySerializer, SecurityAdvisoryDetailSerializer


# Background scheduler for the periodic refresh jobs below, configured with the
# local timezone reported by tzlocal.
# NOTE(review): the scheduler is created and started when this module is
# imported, so every process that imports it (e.g. each web worker) presumably
# runs its own copy of these jobs -- confirm this is intended.
scheduler = BackgroundScheduler(timezone=f'{get_localzone()}')


@register_job(scheduler, 'cron', id='update_vul', hour=0, minute=0)
def update_vul():
    """Cron job at 00:00 every day: call apps.vul.vul.update_vul (imported as upvul).

    The import is aliased so this module-level job name does not shadow it.
    """
    upvul()


@register_job(scheduler, 'cron', id='update_sa', hour=2, minute=0)
def update_sa():
    """Cron job at 02:00 every day: call apps.vul.vul.update_sa (imported as upsa)."""
    upsa()


# Start only after both jobs above have been registered via their decorators.
scheduler.start()


class CommonModelAPIView(viewsets.GenericViewSet):
    """Base viewset that loads the host inventory when the view is constructed.

    ``self._hosts`` holds the hosts returned by ``FetchHost.get_host_list()``,
    with each host's ``status`` translated to ``'running'`` (status 0) or
    ``'offline'`` (anything else).
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._hosts: List[dict] = self._get_host_list()

    def _get_host_list(self) -> List[dict]:
        """Fetch hosts and map their numeric status to 'running'/'offline'.

        New dicts are built instead of mutating the ones returned by
        ``FetchHost``: if that list is cached/shared (not visible here), an
        in-place rewrite would leak the string status to other callers that
        still expect the numeric value.
        """
        return [
            {**host, 'status': 'running' if host['status'] == 0 else 'offline'}
            for host in FetchHost.get_host_list()
        ]

    def _find_unique_host(self, field: str, value: str) -> Union[Dict, None]:
        """Return the only host whose ``field`` equals ``value``.

        Returns None when no host matches or when the match is ambiguous.
        """
        matches = [host for host in self._hosts if host.get(field) == value]
        return matches[0] if len(matches) == 1 else None

    def _get_host_by_ip(self, ip: str) -> Union[Dict, None]:
        """Return the host whose ``ip`` field is ``ip`` (None if missing/ambiguous).

        Only the ``ip`` field is compared, so a hostname or other field that
        happens to equal ``ip`` can no longer produce a false or ambiguous match.
        """
        return self._find_unique_host('ip', ip)

    def _get_host_by_hostname(self, hostname: str) -> Union[Dict, None]:
        """Return the host whose ``hostname`` field is ``hostname`` (None if missing/ambiguous)."""
        return self._find_unique_host('hostname', hostname)

    def _get_all_host_ip(self) -> list:
        """Return the IP of every known host, in inventory order."""
        return [host['ip'] for host in self._hosts]


class VulListView(CommonModelAPIView, ListModelMixin):
    """Security-advisory endpoints: list, detail, and fixing CVEs on hosts.

    Only advisories whose ``hosts`` field is non-empty are exposed, newest first.
    """
    # authentication_classes = [Authentication]
    queryset = SecurityAdvisoryModel.objects.exclude(hosts="").order_by('-created_at')
    serializer_class = SecurityAdvisorySerializer
    lookup_field = 'cve_id'

    def _replace_hosts_with_hostnames(self, items) -> None:
        """Replace each item's ``hosts`` IP list with hostnames, in place.

        IPs that are no longer in the host inventory are kept as the raw IP
        instead of crashing the whole response.
        """
        for item in items:
            hostnames: List[str] = []
            for ip in item.get('hosts') or []:
                host = self._get_host_by_ip(ip)
                hostnames.append(host['hostname'] if host is not None else ip)
            item['hosts'] = hostnames

    def get_vul_list(self, request, *args, **kwargs):
        """Return the (paginated) advisory list, with host IPs shown as hostnames."""
        queryset = self.filter_queryset(self.get_queryset())

        page = self.paginate_queryset(queryset)
        if page is not None:
            data = self.get_serializer(page, many=True).data
            self._replace_hosts_with_hostnames(data)
            return self.get_paginated_response(data)

        # Unpaginated: apply the same host mapping and report the real total.
        data = self.get_serializer(queryset, many=True).data
        self._replace_hosts_with_hostnames(data)
        return success(result=data, total=len(data))

    def _check_fix_params(self, request) -> Tuple[bool, Any]:
        """Validate the payload of a fix request.

        ``request.data['cve_id_list']`` must be a list of objects, each holding
        a ``cve_id`` and a ``hostname`` list.

        Returns:
            ``(True, cve_id_list)`` when valid, otherwise ``(False, message)``.
        """
        payload = request.data if hasattr(request.data, 'get') else {}
        cve_id_list = payload.get('cve_id_list', None)
        if cve_id_list is None:
            return False, 'cve_id_list required params'

        if not isinstance(cve_id_list, list):
            return False, 'params cve_id_list not Array'

        for cve in cve_id_list:
            # Reject malformed items up front instead of raising KeyError later.
            if not isinstance(cve, dict) or 'cve_id' not in cve \
                    or not isinstance(cve.get('hostname'), list):
                return False, 'cve_id_list items require cve_id and hostname Array'

        return True, cve_id_list

    def post(self, request):
        """Fix the requested CVEs on the requested hosts.

        Responds with, per CVE, the hosts that succeeded and those that failed;
        if any host failed the message is 'fix cve failed'.
        """
        ok, result = self._check_fix_params(request)
        if not ok:
            return other_response(code=400, message=result, result={}, success=False)

        failed = False
        data = []
        for cve in result:
            cve_id = cve["cve_id"]
            wanted = {name for name in cve['hostname'] if isinstance(name, str)}
            # IPs of the requested hosts in inventory order; a hostname listed
            # twice no longer sends the same host to fix_cve twice.
            hosts = [host['ip'] for host in self._hosts if host['hostname'] in wanted]
            results = fix_cve(hosts, cve_id, user=request.user)
            logger.debug(results)
            success_hosts = []
            failed_hosts = []
            for ret in results:
                hostname = ret["host"]
                if ret["ret"]["status"] == 0:
                    success_hosts.append(hostname)
                else:
                    failed = True
                    failed_hosts.append({
                        "hosts": hostname,
                        "describe": str(ret["ret"]["result"])
                    })
            # NOTE: the misspelled "sucess_host_list" key is part of the API response.
            data.append({
                "cve_id": cve_id,
                "sucess_host_list": success_hosts,
                "fail_host_list": failed_hosts
            })
        if failed:
            return other_response(message='fix cve failed', code=200, result=data)
        return success(result=data)

    def get_vul_detail(self, request, *args, **kwargs):
        """Return one advisory (looked up by ``cve_id``) with full host records.

        Hosts no longer in the inventory are omitted rather than returned as null.
        """
        instance: SecurityAdvisoryModel = self.get_object()
        ips = [ip for ip in instance.hosts.split(',') if ip]
        nodes = [node for node in map(self._get_host_by_ip, ips) if node is not None]
        result = SecurityAdvisoryDetailSerializer(instance).data
        result['hosts'] = nodes
        return success(result=result)


class VulSummaryView(APIView):
    """Dashboard summary of security-advisory and unfixed-vulnerability counts."""
    # authentication_classes = [Authentication]

    def get(self, request, format=None):
        """Return CVE/host counts for advisories and unfixed vulnerabilities.

        ``fixed_cve`` also carries fix-history counts (all time and today) and
        the start time of the most recent ``update_vul`` job ("" when unknown).
        """
        sa = SecurityAdvisoryModel.objects.exclude(hosts='')
        sa_cve_count, sa_high_cve_count, sa_affect_host_count = self.get_vul_info(sa)
        unfix_vul = get_unfix_cve().exclude(hosts='')
        vul_cve_count, vul_high_cve_count, vul_affect_host_count = self.get_vul_info(unfix_vul)

        latest_scan_time = self._latest_scan_time()

        cvefix_all = SecurityAdvisoryFixHistoryModel.objects.values_list("cve_id", flat=True)
        cvefix_all_count = len(set(cvefix_all))
        # NOTE(review): if fixed_at is a DateTimeField stored in UTC, a string
        # prefix match on the local date can miscount near midnight; consider
        # fixed_at__date=localdate() once the field type is confirmed.
        cvefix_today = cvefix_all.filter(fixed_at__startswith=localdate().strftime("%Y-%m-%d"))
        cvefix_today_count = len(set(cvefix_today))
        data = {
            "fixed_cve": {
                "affect_host_count": sa_affect_host_count,
                "cve_count": sa_cve_count,
                "high_cve_count": sa_high_cve_count,
                "cvefix_today_count": cvefix_today_count,
                "cvefix_all_count": cvefix_all_count,
                "latest_scan_time": latest_scan_time,
            },

            "unfixed_cve": {
                "affected_host_number": vul_affect_host_count,
                "cve_number": vul_cve_count,
                "high_cve_number": vul_high_cve_count,
            }
        }
        return success(result=data)

    @staticmethod
    def _latest_scan_time() -> str:
        """Format the start time of the newest ``update_vul`` job, or "" if none.

        A missing start time is handled explicitly: ``localtime(None)`` would
        return the current time and misreport "now" as the last scan.
        """
        latest_job = VulJobModel.objects.filter(job_name="update_vul") \
            .order_by('-job_start_time').first()
        if latest_job is None or latest_job.job_start_time is None:
            return ""
        try:
            return localtime(latest_job.job_start_time).strftime('%Y-%m-%d %Z %H:%M:%S')
        except ValueError:
            # localtime() rejects naive datetimes (e.g. when USE_TZ is off).
            logger.warning("latest update_vul job_start_time is a naive datetime")
            return ""

    def get_vul_info(self, queryset):
        """Summarize ``queryset`` of CVE records.

        Returns:
            ``(distinct cve_id count, distinct high-level cve_id count,
            distinct affected host count)``, where hosts come from each
            record's comma-separated ``hosts`` field.
        """
        cve_count = len(set(queryset.values_list("cve_id", flat=True)))
        high_cve_count = len(set(queryset.filter(vul_level='high').values_list("cve_id", flat=True)))
        affect_hosts = set()
        # Only the hosts column is needed, so skip loading full model instances.
        for hosts in queryset.values_list("hosts", flat=True):
            if hosts:
                # Ignore empty entries (e.g. from a trailing comma).
                affect_hosts.update(ip for ip in hosts.split(',') if ip)
        return cve_count, high_cve_count, len(affect_hosts)


class SaFixHistListView(APIView):
    """List the security-advisory fix history records that have hosts."""
    # authentication_classes = [Authentication]

    def get(self, request, format=None):
        """Return id, CVE, fix time, fixing user, status and level per record."""
        records = SecurityAdvisoryFixHistoryModel.objects.exclude(hosts="")
        rows = []
        for record in records:
            rows.append({
                "id": record.id,
                "cve_id": record.cve_id,
                "fixed_time": record.fixed_at,
                "fix_user": record.created_by,
                "status": record.status,
                "vul_level": record.vul_level,
            })
        return success(result=rows)


class SaFixHistDetailsView(APIView):
    """Per-host breakdown of one security-advisory fix history record."""
    # authentication_classes = [Authentication]
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._hosts = FetchHost.get_host_list()

    def get_cve2host_details(self, sa_fix_host_obj: SaFixHistToHost):
        """Merge inventory info for the fixed host with its fix status/details.

        The host is matched on its ``ip`` field against ``sa_fix_host_obj.hosts``
        (assumed to hold the host IP -- the fix flow passes IPs to fix_cve).
        Returns {} when the host is not in the inventory.
        """
        host = next((h for h in self._hosts if h.get('ip') == sa_fix_host_obj.hosts), None)
        if host is None:
            return {}
        return {
            "hostname": host.get('hostname'),
            "ip": host.get('ip'),
            # NOTE(review): reads the 'created' key, not 'created_by' -- confirm
            # against the FetchHost host schema.
            "created_by": host.get('created'),
            "created_at": host.get('created_at'),
            "host_status": 'running' if host.get('status') == 0 else 'offline',
            "status": sa_fix_host_obj.status,
            "details": str(sa_fix_host_obj.details),
        }

    def get(self, request, pk, format=None):
        """Return the CVE of fix history ``pk`` and one numbered entry per host.

        Responds with code 400 when the record has no host entries.
        """
        # Evaluate once and pull the parent record in the same query, instead
        # of a separate .first() query plus a lazy FK lookup.
        details = list(
            SaFixHistToHost.objects.filter(sa_fix_hist_id=pk).select_related('sa_fix_hist')
        )
        if not details:
            return other_response(message=f'pk: {pk} not found!', code=400, success=False)

        data = []
        for index, detail_obj in enumerate(details, start=1):
            item = self.get_cve2host_details(detail_obj)
            item["id"] = index
            data.append(item)

        # Every row shares sa_fix_hist_id == pk, so any row yields the CVE id.
        cve_id = details[0].sa_fix_hist.cve_id
        # NOTE: the misspelled "hosts_datail" key is part of the API response.
        return success(result={
            "cve_id": cve_id,
            "hosts_datail": data
        })


class SaFixHistDetailHostView(APIView):
    """Fix result of a single host within one security-advisory fix history record."""
    # authentication_classes = [Authentication]
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._hosts = FetchHost.get_host_list()

    def get(self, request, pk, hostname, format=None):
        """Return status/details of ``hostname`` for fix history ``pk``.

        Responds with code 400 when the hostname is unknown or has no entry
        under ``pk``.
        """
        host = next((h for h in self._hosts if h.get('hostname') == hostname), None)
        if host is None:
            return other_response(message=f'{hostname} not found!', code=400, success=False)

        ip = host['ip']
        # `hosts__contains` is a substring match (10.0.0.1 would also match
        # 10.0.0.10), so use it only to narrow candidates in SQL, then require
        # an exact entry. `hosts` is presumably one IP or a comma-separated
        # list -- both are handled by the split below.
        candidates = SaFixHistToHost.objects.filter(sa_fix_hist_id=pk, hosts__contains=ip)
        match = next(
            (c for c in candidates
             if ip in (entry.strip() for entry in (c.hosts or '').split(','))),
            None,
        )
        if match is None:
            return other_response(message=f'pk: {pk} and hostname: {hostname} not found', code=400, success=False)

        return success(result={
            "hostname": hostname,
            "status": match.status,
            "details": str(match.details),
        })


class UpdateSaView(CommonModelAPIView):
    """On-demand refresh of vulnerability data and security advisories.

    A refresh is refused when there are no hosts, or when the last
    ``update_vul`` job is more recent than ``interval_time`` seconds
    (``settings.INTERVAL_TIME``).
    """
    interval_time = settings.INTERVAL_TIME
    # authentication_classes = [Authentication]

    def _check_has_hosts(self) -> Any:
        """Return the list of known host IPs, or False when there are none."""
        hosts = self._get_all_host_ip()
        return hosts if hosts else False

    def _check_last_sacn_time(self) -> bool:
        """Return True when at least ``interval_time`` seconds have passed since
        the newest ``update_vul`` job (or when there is no such job)."""
        instance = VulJobModel.objects.filter(job_name="update_vul") \
            .order_by('-job_start_time').first()
        if instance is None:
            return True

        # Fall back to the start time when no end time was recorded (job still
        # running or aborted); with neither timestamp, allow a new scan.
        last_time = instance.job_end_time or instance.job_start_time
        if last_time is None:
            return True

        # total_seconds(), not .seconds: .seconds drops whole days, so a scan
        # that ended 1 day + 5 minutes ago would look only 5 minutes old.
        elapsed = (localtime() - last_time).total_seconds()
        return elapsed >= self.interval_time

    def scan(self, request):
        """Refresh vulnerability data unless it was updated too recently.

        If there are no hosts, or the last scan is within the interval, a
        success response with message "forbidden" is returned without scanning.
        """
        # step 1: with no hosts there is nothing to scan
        hosts = self._check_has_hosts()
        if not hosts:
            return success(
                message="forbidden",
                result='current not node'
            )

        # step 2: skip if the last scan is more recent than interval_time
        if not self._check_last_sacn_time():
            return success(
                message="forbidden",
                result="The data has been updated recently,no need to update it again"
            )
        # step 3: update the vulnerability database
        upvul()
        # step 4: update security advisories for the known hosts
        upsa(hosts)
        return success(result="Update security advisory data")


class VulAddrViewSet(viewsets.ModelViewSet):
    """CRUD for vulnerability-source addresses plus a connectivity test action."""
    # authentication_classes = [Authentication]
    queryset = VulAddrModel.objects.all()
    serializer_class = VulAddrListSerializer
    filter_backends = [DjangoFilterBackend]
    filterset_fields = ['name']

    def get_serializer_class(self):
        """Use the list serializer for GET and the modify serializer otherwise."""
        if self.request.method == "GET":
            return VulAddrListSerializer
        return VulAddrModifySerializer

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        # exists() avoids loading every row just to test for emptiness.
        if not queryset.exists():
            return success([], total=0)
        return super().list(request, *args, **kwargs)

    def retrieve(self, request, *args, **kwargs):
        response = super().retrieve(request, *args, **kwargs)
        return success(result=response.data)

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        ser = VulAddrListSerializer(serializer.instance, many=False)
        return success(ser.data, message="新增成功")

    def update(self, request, *args, **kwargs):
        super().update(request, *args, **kwargs)
        return success(result={}, message="修改成功")

    def destroy(self, request, *args, **kwargs):
        super().destroy(request, *args, **kwargs)
        return success(result={}, message="删除成功")

    @action(detail=False, methods=['post'])
    def test_connect(self, request, *args, **kwargs):
        """Send the HTTP request described by the body; report it and its outcome.

        The result holds the raw prepared request ("request") and a status line
        ("status"). Invalid input is reported in "status" like a send failure.

        NOTE(review): this issues an outbound request to an arbitrary
        caller-supplied URL from the server (SSRF risk), and authentication is
        commented out on this viewset -- confirm access control.
        """
        try:
            url, method, headers, params, payload, auth = self.get_req_arg(request.data)
            prepped = requests.Request(method, url, headers=headers, data=payload,
                                       params=params, auth=auth).prepare()
        except (ValueError, requests.exceptions.RequestException) as e:
            # Unknown method / malformed URL: report instead of a 500.
            return success(result={"request": "", "status": f"Status Code: ERROR({e})"}, message="")
        data = {"request": self.get_req_struct(prepped),
                "status": self.get_resp_result(prepped)}
        return success(result=data, message="")

    @staticmethod
    def get_req_arg(body):
        """Build ``(url, method, headers, params, payload, auth)`` from ``body``.

        A default User-Agent is added when none is given. Basic auth is used
        when ``authorization_type`` is "basic" (any case) and
        ``authorization_body`` is present.

        Raises:
            ValueError: if ``body['method']`` is not one of
                ``VulAddrModel.REQUEST_METHOD_CHOICES``.
        """
        # Copy so the default User-Agent is not written back into request data.
        headers = dict(body.get("headers") or {})
        if "User-Agent" not in headers:
            headers[
                "User-Agent"] = "Mozilla/5.0 (X11; Linux x86_64) Chrome/99.0.4844.51"

        authorization_body = body.get("authorization_body")
        if authorization_body and (body.get("authorization_type") or "").lower() == "basic":
            auth = (authorization_body["username"], authorization_body["password"])
        else:
            auth = ()

        requested = body.get("method")
        method = next(
            (choice[1] for choice in VulAddrModel.REQUEST_METHOD_CHOICES if choice[0] == requested),
            None,
        )
        if method is None:
            raise ValueError(f"unsupported request method: {requested!r}")

        return body.get("url"), method, headers, body.get("params"), body.get("body"), auth

    @staticmethod
    def get_req_struct(req):
        """Render a prepared request as 'METHOD URL', headers, and body."""
        req_struct = '{}\n\n{}\n\n{}'.format(
            req.method + ' ' + req.url,
            '\n'.join('{}: {}'.format(k, v) for k, v in req.headers.items()),
            req.body,
        )
        return req_struct

    @staticmethod
    def get_resp_result(req, timeout=10):
        """Send prepared request ``req`` and return a human-readable status line.

        2xx and 304 count as OK. Failures (connection errors, the ``timeout``
        in seconds expiring, ...) are reported in the returned string instead
        of being raised.
        """
        try:
            with requests.Session() as session:
                # stream=True: only the status is needed, don't download the body.
                with session.send(req, timeout=timeout, stream=True) as resp:
                    resp_status = resp.status_code
        except Exception as e:  # best-effort probe: surface any failure to the caller
            return f"Status Code: ERROR({e})"

        # Compare the response code itself (the original compared the
        # `status` module object against 304, which was always False).
        if 200 <= resp_status < 300 or resp_status == 304:
            return f"Status Code: {resp_status} OK"
        return f"Status Code: {resp_status} ERROR"


class HealthViewset(viewsets.GenericViewSet):
    """Health-check endpoint; authentication is disabled (empty authentication_classes)."""

    authentication_classes = []

    def health_check(self, request, *args, **kwargs):
        """Reply with an empty success payload."""
        empty_payload = {}
        return success(result=empty_payload)